"""This code was originally taken from https://github.com/google/prompt-to-
prompt."""

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from tqdm import tqdm


def tensor_to_nparray(image: torch.Tensor):
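    """Map a (B, C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 array,
    keeping only the first batch element."""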
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()[0]
    image = (image * 255).astype(np.uint8)
    return image


def text_under_image(image: np.ndarray,
                     text: str,
                     text_color: Tuple[int, int, int] = (0, 0, 0)):
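    """Return `image` extended by a white bottom margin (20% of the image
    height) with `text` rendered centered inside it."""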
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
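    """Tile images into a `num_rows`-row grid and display it inline.

    `images` may be a single HWC array, a list of HWC arrays, or a 4D
    (N, H, W, C) array; missing grid slots are filled with white.
    """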
    if type(images) is list:
        # pad with blanks so the count is an exact multiple of num_rows
        num_empty = (-len(images)) % num_rows
    elif images.ndim == 4:
        num_empty = (-images.shape[0]) % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8)
              for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones(
        (h * num_rows + offset * (num_rows - 1), w * num_cols + offset *
         (num_cols - 1), 3),
        dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            row, col = i * (h + offset), j * (w + offset)
            image_[row:row + h, col:col + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)
    # return pil_img


def diffusion_step(model,
                   controller,
                   latents,
                   context,
                   t,
                   guidance_scale,
                   low_resource=False):
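    """Run one classifier-free-guidance denoising step.

    The guided prediction is
    `uncond + guidance_scale * (text - uncond)`; with `low_resource=True`
    the two UNet passes run sequentially to reduce peak memory.
    """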
    if low_resource:
        noise_pred_uncond = model.unet(
            latents, t, encoder_hidden_states=context[0])['sample']
        noise_prediction_text = model.unet(
            latents, t, encoder_hidden_states=context[1])['sample']
    else:
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(
            latents_input, t, encoder_hidden_states=context)['sample']
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)['prev_sample']
    latents = controller.step_callback(latents)
    return latents


def latent2image(vae, latents):
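    """Decode latents into a uint8 (N, H, W, C) image array, first undoing
    the Stable Diffusion v1 VAE scaling factor of 0.18215."""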
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents)['sample']
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)
    return image


def init_latent(latent, model, height, width, generator, batch_size):
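    """Sample a single starting latent when none is given and broadcast it
    over the batch, so every prompt denoises from the same noise."""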
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8,
                            width // 8).to(model.device)
    return latent, latents


@torch.no_grad()
def text2image_ldm_stable(model,
                          prompt: List[str],
                          controller,
                          num_inference_steps: int = 50,
                          guidance_scale: Optional[float] = 7.5,
                          generator: Optional[torch.Generator] = None,
                          latent: Optional[torch.FloatTensor] = None,
                          uncond_embeddings=None,
                          start_time=50,
                          return_type='image'):
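    """Sample images from Stable Diffusion while routing attention maps
    through `controller`.

    Per-step `uncond_embeddings` (e.g. from null-text inversion) override
    the empty-prompt embedding when provided.

    A minimal usage sketch, assuming `ldm_stable` is a loaded
    `diffusers.StableDiffusionPipeline` and `AttentionStore` is a
    controller class from the accompanying prompt-to-prompt module:

        controller = AttentionStore()
        images, x_t = text2image_ldm_stable(
            ldm_stable, ['a photo of a cat'], controller,
            generator=torch.Generator().manual_seed(0))
    """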
    batch_size = len(prompt)
    register_attention_control(model, controller)
    height = width = 512

    text_input = model.tokenizer(
        prompt,
        padding='max_length',
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    text_embeddings = model.text_encoder(
        text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [''] * batch_size,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt')
        uncond_embeddings_ = model.text_encoder(
            uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None

    latent, latents = init_latent(latent, model, height, width, generator,
                                  batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            context = torch.cat([
                uncond_embeddings[i].expand(*text_embeddings.shape),
                text_embeddings
            ])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = diffusion_step(
            model,
            controller,
            latents,
            context,
            t,
            guidance_scale,
            low_resource=False)

    if return_type == 'image':
        image = latent2image(model.vae, latents)
    else:
        image = latents
    return image, latent


def register_attention_control(model, controller):
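    """Replace the forward of every self-/cross-attention layer in
    `model.unet` with one that routes the attention probabilities through
    `controller` before they are applied to the values, and record the
    number of patched layers in `controller.num_att_layers`. A no-op
    controller is installed when `controller` is None."""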

    def ca_forward(self, place_in_unet):
        # `to_out` may be a ModuleList of [Linear, Dropout]; keep only the
        # linear projection (dropout is the identity at inference time).
        to_out = self.to_out
        if type(to_out) is torch.nn.ModuleList:
            to_out = self.to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            q = self.to_q(x)
            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x
            k = self.to_k(encoder_hidden_states)
            v = self.to_v(encoder_hidden_states)
            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(v)

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~attention_mask, max_neg_value)

            # attention, what we cannot get enough of
            attn = sim.softmax(dim=-1)
            attn = controller(attn, is_cross, place_in_unet)
            out = torch.einsum('b i j, b j d -> b i d', attn, v)
            out = self.batch_to_head_dim(out)
            return to_out(out)

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_name: str, net_, count, place_in_unet):
        if net_name.endswith('attn2') or net_name.endswith('attn1'):
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_modules()
    for net in sub_nets:
        if 'down' in net[0]:
            cross_att_count += register_recr(net[0], net[1], 0, 'down')
        elif 'up' in net[0]:
            cross_att_count += register_recr(net[0], net[1], 0, 'up')
        elif 'mid' in net[0]:
            cross_att_count += register_recr(net[0], net[1], 0, 'mid')

    controller.num_att_layers = cross_att_count


def get_word_inds(text: str, word_place: int, tokenizer):
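    """Return the token indices corresponding to `word_place` in `text`,
    where `word_place` is a word or its index in the whitespace-split
    prompt; indices are shifted by one for the start-of-text token."""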
    split_text = text.split(' ')
    if type(word_place) is str:
        word_place = [
            i for i, word in enumerate(split_text) if word_place == word
        ]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [
            tokenizer.decode([item]).strip('#')
            for item in tokenizer.encode(text)
        ][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)


def update_alpha_time_word(alpha,
                           bounds: Union[float, Tuple[float, float]],
                           prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
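    """Set the blend weight of one prompt (and optionally a subset of word
    indices) to 1 inside the step range given by `bounds`, expressed as
    fractions of the total steps, and to 0 outside it."""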
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] *
                                                      alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(
        prompts,
        num_steps,
        cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
        tokenizer,
        max_num_words=77):
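    """Build the (num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    schedule of cross-attention replacement weights. `cross_replace_steps`
    is a fraction/range applied to all words, or a dict mapping words to
    ranges with 'default_' as the fallback."""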
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {'default_': cross_replace_steps}
    if 'default_' not in cross_replace_steps:
        cross_replace_steps['default_'] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1,
                                   len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(
            alpha_time_words, cross_replace_steps['default_'], i)
    for key, item in cross_replace_steps.items():
        if key != 'default_':
            inds = [
                get_word_inds(prompts[i], key, tokenizer)
                for i in range(1, len(prompts))
            ]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(
                        alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1,
                                                len(prompts) - 1, 1, 1,
                                                max_num_words)
    return alpha_time_words


def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
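    """Load an image, apply the optional crop margins, center-crop it to a
    square, and return a (1, 3, 512, 512) float tensor in [-1, 1].

    A minimal usage sketch (the path is hypothetical):

        image = load_512('examples/cat.png', device='cuda')
    """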
    if type(image_path) is str:
        image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
    else:
        image = image_path
    h, w, c = image.shape
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = np.array(Image.fromarray(image).resize((512, 512)))
    image = torch.from_numpy(image).float() / 127.5 - 1
    image = image.permute(2, 0, 1).unsqueeze(0).to(device)

    return image
